#
#  样本处理相关的函数库 (helper functions for sample processing)
# 

from ImageHandle import ImageHandler
import math
import numpy as np
import random
import scipy
# 最小二乘求解非线性方程组
from scipy.optimize import leastsq,fsolve,root
from osgeo import gdal,gdalconst
import pandas as pds
from scipy import interpolate
from multiprocessing import pool
# Module-level constants
# Shared ImageHandler instance, created once at import time; read_tiff uses it to load rasters.
imageHandler=ImageHandler()


# 样本处理函数 (sample-processing functions)
def read_sample_csv(csv_path):
    """Read the field-sample CSV into a list of sample rows.

    Only the columns '样本号' (sample number), '经度' (longitude, x),
    '纬度' (latitude, y), '叶面积指数' (LAI) and '后向散射系数'
    (backscatter) are used; any other columns are ignored.

    Args:
        csv_path (string): absolute path of the sample CSV.

    Returns:
        list: one entry per CSV row, in file order::

            [0, sample_number, lon, lat, lai, sigma_linear]

        The leading ``0`` is a placeholder occupying the date slot (no date is
        read from the file). ``sigma_linear`` is ``10 ** (backscatter / 10)``,
        i.e. the CSV value converted from dB to a linear scale (assuming the
        CSV stores dB).
    """
    wanted = ['样本号', '经度', '纬度', '叶面积指数', "后向散射系数"]
    table = pds.read_csv(csv_path).loc[:, wanted]

    def to_sample(row):
        # `row` is an index label; read_csv produces the default 0..n-1 index,
        # so labels coincide with row positions.
        sigma_db = float(table.loc[row, "后向散射系数"])
        return [
            0,  # date placeholder
            table.loc[row, '样本号'],
            table.loc[row, '经度'],  # lon, x
            table.loc[row, '纬度'],  # lat, y
            table.loc[row, '叶面积指数'],
            10 ** (sigma_db / 10),
        ]

    return [to_sample(row) for row in range(len(table))]
def read_tiff(tiff_path):
    """Load a raster file into a dictionary.

    Args:
        tiff_path (string): path of the raster file.

    Returns:
        dict: ``{'proj': ..., 'geotrans': ..., 'data': ...}`` holding, in that
        order, the three values returned by ``imageHandler.read_img``
        (projection, geotransform and pixel array).
    """
    projection, geotransform, pixels = imageHandler.read_img(tiff_path)
    tiff = dict(proj=projection, geotrans=geotransform)
    tiff['data'] = pixels
    return tiff
    
def ReprojectImages2(in_tiff_path, ref_tiff_path, out_tiff_path, resampleAlg=gdalconst.GRA_Bilinear):
    """Resample an input raster onto the grid of a reference raster.

    A new GeoTIFF is created at ``out_tiff_path`` with the reference raster's
    size, band count, band-1 data type, geotransform and projection, and its
    pixels are filled from the input raster with ``gdal.Warp``.

    Args:
        in_tiff_path (string): path of the raster to resample.
        ref_tiff_path (string): path of the reference raster defining the output grid.
        out_tiff_path (string): path of the GeoTIFF to create.
        resampleAlg (gdalconst): GDAL resampling algorithm, e.g.
            ``gdalconst.GRA_Bilinear`` (default) or ``gdalconst.GRA_NearestNeighbour``.

    Raises:
        RuntimeError: if a raster cannot be opened, the output cannot be
            created, or the warp reports failure.
    """
    inputrasfile = gdal.Open(in_tiff_path, gdal.GA_ReadOnly)
    if inputrasfile is None:
        raise RuntimeError("cannot open input raster: {}".format(in_tiff_path))
    inputProj = inputrasfile.GetProjection()

    referencefile = gdal.Open(ref_tiff_path, gdal.GA_ReadOnly)
    if referencefile is None:
        raise RuntimeError("cannot open reference raster: {}".format(ref_tiff_path))
    referencefileProj = referencefile.GetProjection()
    referencefileTrans = referencefile.GetGeoTransform()
    # NOTE(review): the output data type and band count come from the
    # reference raster, not the input; confirm this is intended.
    dataType = referencefile.GetRasterBand(1).DataType
    x = referencefile.RasterXSize
    y = referencefile.RasterYSize
    nbands = referencefile.RasterCount

    # Create the output file on the reference grid (projection + geotransform).
    driver = gdal.GetDriverByName('GTiff')
    output = driver.Create(out_tiff_path, x, y, nbands, dataType)
    if output is None:
        raise RuntimeError("cannot create output raster: {}".format(out_tiff_path))
    output.SetGeoTransform(referencefileTrans)
    output.SetProjection(referencefileProj)

    # Use the caller's resampling choice (previously ignored: bilinear was hard-coded).
    options = gdal.WarpOptions(srcSRS=inputProj, dstSRS=referencefileProj, resampleAlg=resampleAlg)
    if not gdal.Warp(output, in_tiff_path, options=options):
        raise RuntimeError("gdal.Warp failed for {} -> {}".format(in_tiff_path, out_tiff_path))

    # Flush and dereference the datasets so GDAL finishes writing and closes the files.
    output.FlushCache()
    output = None
    inputrasfile = None
    referencefile = None

def combine_sample_attr(sample_list, attr_tiff, *, half_window=4):
    """Append a raster attribute, averaged around each sample, to the samples.

    Each sample's position (``item[2]`` as x / longitude, ``item[3]`` as
    y / latitude) is mapped to fractional pixel coordinates through the
    inverse of ``attr_tiff['geotrans']``. The sample coordinates are assumed
    to be in the raster's CRS; this is not checked here.

    Samples whose surrounding pixels (floor/ceil of the pixel coordinates)
    fall outside the raster are dropped. For the others, the mean of
    ``data[floor(y)-h : ceil(y)+h, floor(x)-h : ceil(x)+h]`` (h =
    ``half_window``, end-exclusive slices, clipped to the raster) is appended.

    Args:
        sample_list (list): samples as lists; only indices 2 and 3 are read.
        attr_tiff (dict): raster dict with a ``'geotrans'`` GDAL geotransform
            and a ``'data'`` array indexed ``[row, col]`` (as built by read_tiff;
            assumed 2-D single band — TODO confirm for multi-band data).
        half_window (int): half-size of the averaging window in pixels
            (default 4, the original behaviour).

    Returns:
        list: new sample lists ``list(sample) + [window_mean]`` for every sample
        inside the raster. The input samples are not modified.

    Raises:
        ValueError: if the geotransform is not invertible.
    """
    data = attr_tiff['data']
    n_rows, n_cols = data.shape[0], data.shape[1]
    inv_gt = gdal.InvGeoTransform(attr_tiff['geotrans'])
    if inv_gt is None:
        raise ValueError("geotransform is not invertible: {}".format(attr_tiff['geotrans']))

    result = []
    for sample_item in sample_list:
        sample_lon = sample_item[2]
        sample_lat = sample_item[3]
        sample_in_tiff_x = inv_gt[0] + inv_gt[1] * sample_lon + inv_gt[2] * sample_lat  # column
        sample_in_tiff_y = inv_gt[3] + inv_gt[4] * sample_lon + inv_gt[5] * sample_lat  # row
        x_min = int(np.floor(sample_in_tiff_x))
        x_max = int(np.ceil(sample_in_tiff_x))
        y_min = int(np.floor(sample_in_tiff_y))
        y_max = int(np.ceil(sample_in_tiff_y))
        if x_min < 0 or y_min < 0 or x_max >= n_cols or y_max >= n_rows:
            continue
        # Clip the window to the raster. The lower bound previously tested
        # `x_min-9>=0` while subtracting 4, which stretched the window to
        # row/column 0 for indices 4..8.
        x_lo = max(x_min - half_window, 0)
        y_lo = max(y_min - half_window, 0)
        x_hi = min(x_max + half_window, n_cols)  # exclusive slice end
        y_hi = min(y_max + half_window, n_rows)  # exclusive slice end
        window_mean = np.mean(data[y_lo:y_hi, x_lo:x_hi])
        # Copy instead of appending in place so the caller's samples stay intact.
        result.append(list(sample_item) + [window_mean])
    return result

def _unpack_checked_sample(item):
    """Return (code, lai, csv_sigma, soil, inc, sigma) from a 9- or 10-field sample."""
    if len(item) not in (9, 10):
        raise ValueError(
            "sample must have 9 or 10 fields, got {}: {!r}".format(len(item), item))
    return item[1], item[4], item[5], item[6], item[7], item[8]


def _sample_passes_checks(item):
    """Return True when the sample's sigma, incidence angle, soil and LAI are in range."""
    _, sample_lai, _, sample_soil, sample_inc, sample_sigma = _unpack_checked_sample(item)
    # NOTE(review): NaN values make every comparison False and therefore pass.
    if sample_sigma <= 0:
        return False
    if (sample_inc * 180 / np.pi) > 90:  # incidence angle is in radians
        return False
    if sample_soil <= 0 or sample_soil >= 1:
        return False
    if sample_lai <= 0 or sample_lai >= 20:
        return False
    return True


def _plot_checked_samples(samples):
    """Scatter LAI against tiff sigma and CSV sigma, labelling points by sample code."""
    from matplotlib import pyplot as plt
    fields = [_unpack_checked_sample(item) for item in samples]
    text_label = [f[0] for f in fields]
    lai = [f[1] for f in fields]
    csv_sigmas = [f[2] for f in fields]
    sigma = [f[5] for f in fields]

    plt.scatter(np.array(lai), np.array(sigma), label="lai-tiff_sigma")
    for label, x, y in zip(text_label, lai, sigma):
        plt.annotate(label, xy=(x, y))

    plt.scatter(np.array(lai), np.array(csv_sigmas), label="lai-csv_sigmas")
    for label, x, y in zip(text_label, lai, csv_sigmas):
        plt.annotate(label, xy=(x, y))
    plt.legend()
    plt.show()


def check_sample(sample_list, *, plot=True):
    """Drop implausible samples and optionally plot the remaining ones.

    Each sample has 9 fields::

        [date, plot_code, lon, lat, LAI, csv_sigma, soil_moisture, incidence_angle, sigma]

    or 10 fields with an NDVI value appended. A sample is kept only when
    ``sigma > 0``, the incidence angle (in radians) is at most 90 degrees,
    ``0 < soil_moisture < 1`` and ``0 < LAI < 20``.

    Args:
        sample_list (list): samples to check.
        plot (bool): when True (default, the original behaviour) draw LAI
            against tiff sigma and CSV sigma for the kept samples and call
            ``plt.show()``. Pass False for non-interactive use.

    Returns:
        list: the kept samples (the same objects), in input order.

    Raises:
        ValueError: if a sample does not have 9 or 10 fields.
    """
    result = [item for item in sample_list if _sample_passes_checks(item)]
    if plot:
        _plot_checked_samples(result)
    return result



def split_sample_list(sample_list, train_ratio, *, rng=None):
    """Randomly split samples into a training set and a test set.

    Each sample independently goes to the training set when
    ``draw() <= train_ratio`` (draw uniform in [0, 1)), so the set sizes are
    only approximately proportional to ``train_ratio``.

    Args:
        sample_list (list): samples to split.
        train_ratio (double): probability that a sample goes to the training set.
        rng (random.Random, optional): source of randomness. Defaults to the
            module-level ``random`` functions (global state, the original
            behaviour). Pass ``random.Random(seed)`` for a reproducible split.

    Returns:
        list: ``[sample_train, sample_test]``, each preserving input order.
    """
    draw = random.random if rng is None else rng.random
    sample_train = []
    sample_test = []
    for sample in sample_list:
        if draw() <= train_ratio:
            sample_train.append(sample)
        else:
            sample_test.append(sample)
    return [sample_train, sample_test]